{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_pack = mz.datasets.wiki_qa.load_data('train', task='ranking')\n",
    "valid_pack = mz.datasets.wiki_qa.load_data('dev', task='ranking', filter=True)\n",
    "predict_pack = mz.datasets.wiki_qa.load_data('test', task='ranking', filter=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 6857.02it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 18841/18841 [00:04<00:00, 4297.68it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 874865.84it/s]\n",
      "Building FrequencyFilterUnit from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 194728.90it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 108324.49it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 650665.49it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 765411.22it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 234263/234263 [00:00<00:00, 2593163.08it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8161.95it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 18841/18841 [00:04<00:00, 4133.19it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 195422.81it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 262159.47it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 193099.20it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 636493.22it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 786763.45it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 133174.47it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 106807.41it/s]\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=40, remove_stop_words=True)\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 122/122 [00:00<00:00, 8129.53it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 1115/1115 [00:00<00:00, 3435.25it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 125753.56it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 116059.22it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 97521.61it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 125387.18it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 369783.27it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 69981.55it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 9195.61it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 237/237 [00:00<00:00, 5578.82it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit: 100%|██████████| 2300/2300 [00:00<00:00, 4034.12it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 203912.56it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 203033.10it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 214423.19it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 296112.61it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 677260.54it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 96229.43it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 107222.32it/s]\n"
     ]
    }
   ],
   "source": [
    "valid_pack_processed = preprocessor.transform(valid_pack)\n",
    "predict_pack_processed = preprocessor.transform(predict_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss())\n",
    "ranking_task.metrics = [\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n",
    "    mz.metrics.MeanAveragePrecision()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 40)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             4963800     text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "matching_layer_1 (MatchingLayer (None, 10, 40, 1)    0           embedding[0][0]                  \n",
      "                                                                 embedding[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_1 (Conv2D)               (None, 10, 40, 16)   160         matching_layer_1[0][0]           \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_2 (Conv2D)               (None, 10, 40, 32)   4640        conv2d_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "dpool_index (InputLayer)        (None, 10, 40, 2)    0                                            \n",
      "__________________________________________________________________________________________________\n",
      "dynamic_pooling_layer_1 (Dynami (None, 3, 10, 32)    0           conv2d_2[0][0]                   \n",
      "                                                                 dpool_index[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "flatten_1 (Flatten)             (None, 960)          0           dynamic_pooling_layer_1[0][0]    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 960)          0           flatten_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 1)            961         dropout_1[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 4,969,561\n",
      "Trainable params: 4,969,561\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.MatchPyramid()\n",
    "model.params['input_shapes'] = preprocessor.context['input_shapes']\n",
    "model.params['task'] = ranking_task\n",
    "model.params['embedding_input_dim'] = preprocessor.context['vocab_size']\n",
    "model.params['embedding_output_dim'] = 300\n",
    "model.params['embedding_trainable'] = True\n",
    "model.params['num_blocks'] = 2\n",
    "model.params['kernel_count'] = [16, 32]\n",
    "model.params['kernel_size'] = [[3, 3], [3, 3]]\n",
    "model.params['dpool_size'] = [3, 10]\n",
    "model.params['optimizer'] = 'adam'\n",
    "model.params['dropout_rate'] = 0.1\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=300)\n",
    "embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "102"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_generator = mz.DPoolPairDataGenerator(train_pack_processed,\n",
    "                                            fixed_length_left=10,\n",
    "                                            fixed_length_right=40,\n",
    "                                            num_dup=2,\n",
    "                                            num_neg=1,\n",
    "                                            batch_size=20)\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "118"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict_generator = mz.DPoolDataGenerator(predict_pack_processed,\n",
    "                                          fixed_length_left=10,\n",
    "                                          fixed_length_right=40,\n",
    "                                          batch_size=20)\n",
    "len(predict_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_x, pred_y = predict_generator[:]\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "102/102 [==============================] - 6s 60ms/step - loss: 0.7690\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5692250932507502 - normalized_discounted_cumulative_gain@5(0.0): 0.6195008780055599 - mean_average_precision(0.0): 0.5879674081412779\n",
      "Epoch 2/20\n",
      "102/102 [==============================] - 7s 73ms/step - loss: 0.4052\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5577580076131701 - normalized_discounted_cumulative_gain@5(0.0): 0.6208268248908941 - mean_average_precision(0.0): 0.5803710108112564\n",
      "Epoch 3/20\n",
      "102/102 [==============================] - 8s 82ms/step - loss: 0.2382\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5374999946557393 - normalized_discounted_cumulative_gain@5(0.0): 0.5998451250432698 - mean_average_precision(0.0): 0.5588963866653741\n",
      "Epoch 4/20\n",
      "102/102 [==============================] - 5s 54ms/step - loss: 0.1180\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5785198780057831 - normalized_discounted_cumulative_gain@5(0.0): 0.6407731712306632 - mean_average_precision(0.0): 0.5962996584645877\n",
      "Epoch 5/20\n",
      "102/102 [==============================] - 5s 53ms/step - loss: 0.0645\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.56564221992718 - normalized_discounted_cumulative_gain@5(0.0): 0.6216871292635173 - mean_average_precision(0.0): 0.5829349126026341\n",
      "Epoch 6/20\n",
      "102/102 [==============================] - 5s 54ms/step - loss: 0.0285\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5298863790926747 - normalized_discounted_cumulative_gain@5(0.0): 0.5959624363658788 - mean_average_precision(0.0): 0.5525931628401359\n",
      "Epoch 7/20\n",
      "102/102 [==============================] - 6s 54ms/step - loss: 0.0211\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5294486727252763 - normalized_discounted_cumulative_gain@5(0.0): 0.6015588031467002 - mean_average_precision(0.0): 0.5593813397817251\n",
      "Epoch 8/20\n",
      "102/102 [==============================] - 6s 58ms/step - loss: 0.0195\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5295443052661813 - normalized_discounted_cumulative_gain@5(0.0): 0.5925552472657848 - mean_average_precision(0.0): 0.5505374343190799\n",
      "Epoch 9/20\n",
      "102/102 [==============================] - 5s 52ms/step - loss: 0.0108\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5415795026826955 - normalized_discounted_cumulative_gain@5(0.0): 0.6094848815320798 - mean_average_precision(0.0): 0.5632085027021736\n",
      "Epoch 10/20\n",
      "102/102 [==============================] - 6s 59ms/step - loss: 0.0067\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.537292156609087 - normalized_discounted_cumulative_gain@5(0.0): 0.5970539498351208 - mean_average_precision(0.0): 0.5470889590976838\n",
      "Epoch 11/20\n",
      "102/102 [==============================] - 6s 54ms/step - loss: 0.0054\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5483017787175871 - normalized_discounted_cumulative_gain@5(0.0): 0.6200705453459617 - mean_average_precision(0.0): 0.5745993018936056\n",
      "Epoch 12/20\n",
      "102/102 [==============================] - 6s 56ms/step - loss: 0.0098\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5440587013588177 - normalized_discounted_cumulative_gain@5(0.0): 0.6059175976290668 - mean_average_precision(0.0): 0.5580739893844884\n",
      "Epoch 13/20\n",
      "102/102 [==============================] - 6s 58ms/step - loss: 0.0023\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5593334311001811 - normalized_discounted_cumulative_gain@5(0.0): 0.6175112142313822 - mean_average_precision(0.0): 0.5743072357787548\n",
      "Epoch 14/20\n",
      "102/102 [==============================] - 6s 57ms/step - loss: 0.0044\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5492339945971083 - normalized_discounted_cumulative_gain@5(0.0): 0.6122065841195945 - mean_average_precision(0.0): 0.5712218041352198\n",
      "Epoch 15/20\n",
      "102/102 [==============================] - 6s 55ms/step - loss: 0.0048\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5427448167833578 - normalized_discounted_cumulative_gain@5(0.0): 0.6009002314260246 - mean_average_precision(0.0): 0.558326078481514\n",
      "Epoch 16/20\n",
      "102/102 [==============================] - 6s 59ms/step - loss: 0.0046\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5482189067638121 - normalized_discounted_cumulative_gain@5(0.0): 0.6088392648186538 - mean_average_precision(0.0): 0.5713557393145999\n",
      "Epoch 17/20\n",
      "102/102 [==============================] - 6s 60ms/step - loss: 0.0013\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5482189067638121 - normalized_discounted_cumulative_gain@5(0.0): 0.6088392648186538 - mean_average_precision(0.0): 0.5713557393145999\n",
      "Epoch 17/20\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5382864813405173 - normalized_discounted_cumulative_gain@5(0.0): 0.5989986083327463 - mean_average_precision(0.0): 0.5588807298665582\n",
      "Epoch 18/20\n",
      "102/102 [==============================] - 6s 58ms/step - loss: 0.0033\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5399279919688914 - normalized_discounted_cumulative_gain@5(0.0): 0.602654397161311 - mean_average_precision(0.0): 0.5650284088367032\n",
      "Epoch 19/20\n",
      "102/102 [==============================] - 6s 58ms/step - loss: 0.0043\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.533411266326293 - normalized_discounted_cumulative_gain@5(0.0): 0.6015080030643855 - mean_average_precision(0.0): 0.5549739875689242\n",
      "Epoch 20/20\n",
      "102/102 [==============================] - 6s 60ms/step - loss: 0.0020\n",
      "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5486963704637519 - normalized_discounted_cumulative_gain@5(0.0): 0.605962157344213 - mean_average_precision(0.0): 0.5658641479427621\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, epochs=20, callbacks=[evaluate], workers=30, use_multiprocessing=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{normalized_discounted_cumulative_gain@3(0.0): 0.5486963704637519,\n",
       " normalized_discounted_cumulative_gain@5(0.0): 0.605962157344213,\n",
       " mean_average_precision(0.0): 0.5658641479427621}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(pred_x, pred_y, batch_size=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
